## A user-defined function that implements the equations of the LSTM model.
## See "LSTM-based model predictive control with discrete inputs for irrigation scheduling" by
## Agyeman B.T., Sahoo R.S., Liu J., and Shah S.L. for details of the LSTM equations.

#Imports
import  numpy as np


#Function that implements the sigmoid function
def SigmoidFunction(x):
    """Return the logistic sigmoid 1 / (1 + exp(-x)), evaluated element-wise.

    The value is computed from exp(-|x|), which never exceeds 1, so inputs of
    large magnitude cannot overflow the exponential (the naive form
    overflows for very negative x and emits a RuntimeWarning).

    Args:
        x: A scalar or array-like of real numbers.

    Returns:
        A NumPy scalar for scalar input, otherwise an ndarray with the same
        shape as ``x``.
    """
    x = np.asarray(x)
    # z is always in (0, 1], so neither branch below can overflow.
    z = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + exp(-x));  x < 0: exp(x) / (1 + exp(x)) -- same function.
    result = np.where(x >= 0, 1 / (1 + z), z / (1 + z))
    # np.where returns a 0-d array for scalar input; unwrap it to a scalar.
    return result[()] if result.ndim == 0 else result

#LSTM model equations
def _split_gate_params(kernel, recurrent_kernel, bias, units):
    """Slice fused LSTM parameters into per-gate, pre-transposed pieces.

    Columns of ``kernel`` / ``recurrent_kernel`` and entries of ``bias`` are
    taken in four consecutive groups of ``units``: input gate, forget gate,
    cell state (candidate), output gate. The last group runs to the end of
    the array.

    Returns:
        A list of four ``(W^T, U^T, b)`` tuples in the order
        input, forget, cell, output.
    """
    edges = (
        (0, units),
        (units, units * 2),
        (units * 2, units * 3),
        (units * 3, None),
    )
    # Transposing here (once) keeps the time loop free of invariant work.
    return [
        (
            np.transpose(kernel[:, lo:hi]),
            np.transpose(recurrent_kernel[:, lo:hi]),
            np.transpose(bias[lo:hi]),
        )
        for lo, hi in edges
    ]


def _lstm_cell_step(x, h_prev, c_prev, gate_params):
    """Advance one LSTM layer by a single time step.

    Args:
        x: Input to the layer at this step.
        h_prev: Hidden state from the previous step.
        c_prev: Cell state from the previous step.
        gate_params: Output of ``_split_gate_params`` for this layer.

    Returns:
        The tuple ``(h_t, c_t)`` of new hidden and cell states.
    """
    (W_i, U_i, b_i), (W_f, U_f, b_f), (W_c, U_c, b_c), (W_o, U_o, b_o) = gate_params

    forget_gate = SigmoidFunction(np.matmul(W_f, x) + np.matmul(U_f, h_prev) + b_f)
    input_gate = SigmoidFunction(np.matmul(W_i, x) + np.matmul(U_i, h_prev) + b_i)
    output_gate = SigmoidFunction(np.matmul(W_o, x) + np.matmul(U_o, h_prev) + b_o)
    candidate = np.tanh(np.matmul(W_c, x) + np.matmul(U_c, h_prev) + b_c)

    # New cell state: keep part of the old state, add part of the candidate.
    c_t = np.multiply(forget_gate, c_prev) + np.multiply(input_gate, candidate)
    h_t = np.multiply(output_gate, np.tanh(c_t))
    return h_t, c_t


def ObtainLstmModel(kernel_1, recurrentKernel_1, bias_1,kernel_2, recurrentKernel_2, bias_2, kernel_dense, bias_dense,units_1, units_2, sequenceLength, inputs):
    """Evaluate a two-layer stacked LSTM followed by a dense (affine) output.

    Both LSTM layers start from all-zero hidden and cell states. At every
    time step, layer 1 consumes ``inputs[j]`` and layer 2 consumes the hidden
    state layer 1 has just produced. After the last step the hidden state of
    layer 2 is mapped through ``kernel_dense^T @ h + bias_dense``.

    Args:
        kernel_1, recurrentKernel_1, bias_1: Fused parameters of layer 1.
            The gates are stored in consecutive groups of ``units_1`` along
            the last axis in the order input, forget, cell, output.
            (Presumably shaped (input_dim, 4*units_1), (units_1, 4*units_1)
            and (4*units_1,) -- confirm against the code that exports them.)
        kernel_2, recurrentKernel_2, bias_2: Same for layer 2, grouped by
            ``units_2``.
        kernel_dense, bias_dense: Weights and bias of the output layer.
        units_1: Number of hidden units in layer 1.
        units_2: Number of hidden units in layer 2.
        sequenceLength: Number of time steps to process; must be >= 1.
        inputs: Sequence indexed by time step; ``inputs[j]`` is the input
            vector at step ``j``.

    Returns:
        The dense-layer output computed from the final hidden state of
        layer 2.

    Raises:
        ValueError: If ``sequenceLength`` is smaller than 1 (previously this
            surfaced as an UnboundLocalError after the loop).
    """
    if sequenceLength < 1:
        raise ValueError(
            "sequenceLength must be at least 1, got {}".format(sequenceLength)
        )

    layer_1 = _split_gate_params(kernel_1, recurrentKernel_1, bias_1, units_1)
    layer_2 = _split_gate_params(kernel_2, recurrentKernel_2, bias_2, units_2)

    # Zero initial hidden/cell states for both layers.
    h_1 = np.zeros(units_1)
    c_1 = np.zeros(units_1)
    h_2 = np.zeros(units_2)
    c_2 = np.zeros(units_2)

    for j in range(sequenceLength):
        h_1, c_1 = _lstm_cell_step(inputs[j], h_1, c_1, layer_1)
        # Layer 2 is fed the hidden state layer 1 produced at this same step.
        h_2, c_2 = _lstm_cell_step(h_1, h_2, c_2, layer_2)

    O_t = np.transpose(bias_dense) + np.matmul(np.transpose(kernel_dense), h_2)

    return O_t